import os

from plat.plat_train import get_train_info, plat_train_val, plat_trains

# Compare the results of three models.
# Model types: GCN, GAT, GatedGCN.
if __name__ == '__main__':
    # Directory holding one training log per model. This is a relative path;
    # assuming get_train_info opens it as given, it resolves against the
    # current working directory -- TODO confirm the intended launch directory.
    log_path = "../log/compare_model/"

    # Model name -> log file name. Dict insertion order fixes the order in
    # which logs are parsed and results are plotted; add an entry here to
    # include another model in the comparison.
    model_logs = {
        'gcn': "gcn.log",
        'gat': "gat.log",
        'gatedgcn': "gatedgcn.log",
    }

    train_dict = dict()
    val_dict = dict()
    # get_train_info yields a (train, val) pair for each log file.
    for model, log_file in model_logs.items():
        log_file_path = os.path.join(log_path, log_file)
        train_dict[model], val_dict[model] = get_train_info(log_file_path)

    # Per-model call with that model's train and val results
    # (presumably one train-vs-validation plot per model).
    for model in model_logs:
        plat_train_val(model, train_dict[model], val_dict[model])

    # All models' train results, then all models' val results
    # (presumably combined comparison plots across models).
    plat_trains(train_dict)
    plat_trains(val_dict)
